/*
 * Copyright 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.room.solver.query.result

import androidx.room.compiler.codegen.CodeLanguage
import androidx.room.compiler.codegen.XCodeBlock
import androidx.room.compiler.codegen.XCodeBlock.Builder.Companion.applyTo
import androidx.room.compiler.codegen.XTypeName
import androidx.room.compiler.codegen.buildCodeBlock
import androidx.room.compiler.processing.XNullability
import androidx.room.compiler.processing.XType
import androidx.room.ext.CommonTypeNames
import androidx.room.ext.CommonTypeNames.ARRAY_LIST
import androidx.room.ext.CommonTypeNames.HASH_SET
import androidx.room.ext.KotlinCollectionMemberNames
import androidx.room.ext.KotlinTypeNames
import androidx.room.solver.CodeGenScope
import androidx.room.solver.query.result.MultimapQueryResultAdapter.MapType.Companion.isSparseArray
import androidx.room.vo.ColumnIndexVar

/**
 * This is an intermediary adapter class that enables nested multimap return types in DAOs.
 *
 * The [MapValueResultAdapter] sealed class is extended by 2 classes, [NestedMapValueResultAdapter]
 * and [EndMapValueResultAdapter]. These adapters are wrappers for the adapters at different levels
 * of nested maps. Each level of nesting of a map is represented by a [NestedMapValueResultAdapter],
 * except the innermost level which is represented by an [EndMapValueResultAdapter].
 *
 * For example, if a DAO function returns a `Map<A, Map<B, Map<C, D>>>`, `Map<C, D>` is represented
 * by an [EndMapValueResultAdapter], and the outer 2 levels are represented by a
 * [NestedMapValueResultAdapter] each.
 *
 * A [NestedMapValueResultAdapter] can wrap either another [NestedMapValueResultAdapter] or an
 * [EndMapValueResultAdapter], whereas an [EndMapValueResultAdapter] does not wrap another adapter
 * and only contains row adapters for the innermost map.
 */
sealed class MapValueResultAdapter(val rowAdapters: List<RowAdapter>) {

    /** True if this adapter requires key checking due to its values being passed by reference. */
    abstract fun requiresContainsKeyCheck(): Boolean

    /** Left-Hand-Side of a Map value type arg initialization. */
    abstract fun getDeclarationTypeName(): XTypeName

    /** Right-Hand-Side of a Map value type arg initialization. */
    abstract fun getInstantiationCodeBlock(): XCodeBlock

    abstract fun convert(
        scope: CodeGenScope,
        valuesVarName: String,
        stmtVarName: String,
        dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
        addPutValueCode: XCodeBlock.Builder.(String, Boolean) -> Unit = { _, _ -> },
    )

    abstract fun generateContinueColumnCheck(
        scope: CodeGenScope,
        stmtVarName: String,
        dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
    )

    /**
     * Adapter for any level of a nested map that is not the innermost "End" map. It carries the
     * key information of its level together with the adapter of the value stored under that key.
     *
     * [convert] generates the code that reads the key of this level and then delegates to the
     * wrapped [NestedMapValueResultAdapter] or [EndMapValueResultAdapter] (depending on the level
     * of nesting) to generate the code that resolves the value.
     */
    class NestedMapValueResultAdapter(
        private val keyRowAdapter: RowAdapter,
        private val keyTypeArg: XType,
        private val mapType: MultimapQueryResultAdapter.MapType,
        private val mapValueResultAdapter: MapValueResultAdapter,
    ) :
        MapValueResultAdapter(
            rowAdapters = listOf(keyRowAdapter) + mapValueResultAdapter.rowAdapters
        ) {

        private val keyTypeName = keyTypeArg.asTypeName()

        /** A nested adapter always asks its parent for the contains-key check. */
        override fun requiresContainsKeyCheck(): Boolean {
            return true
        }

        /**
         * Type name of the map at this level: the class of [mapType] parametrized by the wrapped
         * adapter's declaration type, preceded by the key type for the non-sparse map kinds.
         */
        override fun getDeclarationTypeName() =
            when (mapType) {
                // Sparse kinds take the value type as their only type argument.
                MultimapQueryResultAdapter.MapType.LONG_SPARSE,
                MultimapQueryResultAdapter.MapType.INT_SPARSE ->
                    mapType.className.parametrizedBy(
                        mapValueResultAdapter.getDeclarationTypeName()
                    )
                MultimapQueryResultAdapter.MapType.DEFAULT,
                MultimapQueryResultAdapter.MapType.ARRAY_MAP ->
                    mapType.className.parametrizedBy(
                        keyTypeName,
                        mapValueResultAdapter.getDeclarationTypeName(),
                    )
            }

        /** Code block that creates a new, empty instance of the map at this level. */
        override fun getInstantiationCodeBlock(): XCodeBlock =
            when (mapType) {
                MultimapQueryResultAdapter.MapType.DEFAULT ->
                    // LinkedHashMap is used as impl to preserve key ordering for ordered
                    // query results.
                    buildCodeBlock { language ->
                        val linkedHashMap =
                            when (language) {
                                CodeLanguage.JAVA -> CommonTypeNames.LINKED_HASH_MAP
                                CodeLanguage.KOTLIN -> KotlinTypeNames.LINKED_HASH_MAP
                            }
                        add(
                            XCodeBlock.ofNewInstance(
                                linkedHashMap.parametrizedBy(
                                    keyTypeName,
                                    mapValueResultAdapter.getDeclarationTypeName(),
                                )
                            )
                        )
                    }
                // Every other kind is instantiated as exactly the type it is declared with.
                MultimapQueryResultAdapter.MapType.ARRAY_MAP,
                MultimapQueryResultAdapter.MapType.LONG_SPARSE,
                MultimapQueryResultAdapter.MapType.INT_SPARSE ->
                    XCodeBlock.ofNewInstance(getDeclarationTypeName())
            }

        /**
         * Generates the code that reads this level's key and then lets the wrapped adapter
         * generate the code for the value.
         *
         * The incoming [addPutValueCode] is intentionally not used: the wrapped adapter receives a
         * new put lambda that targets the variable resolved at this level.
         *
         * @param scope scope providing the code builder and temporary variable names
         * @param valuesVarName name of the variable holding the map this level's key belongs to
         * @param stmtVarName name of the statement variable the row adapters read from
         * @param dupeColumnsIndexAdapter forwarded as-is to the wrapped adapter
         * @param addPutValueCode ignored by this implementation, see above
         */
        override fun convert(
            scope: CodeGenScope,
            valuesVarName: String,
            stmtVarName: String,
            dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
            addPutValueCode: XCodeBlock.Builder.(String, Boolean) -> Unit,
        ) {
            val builder = scope.builder
            val isSparse = mapType.isSparseArray()

            // Read map key
            val keyVarName = scope.getTmpVar("_key")
            builder.addLocalVariable(keyVarName, keyTypeName)
            keyRowAdapter.convert(keyVarName, stmtVarName, scope)

            // Name of the variable the wrapped adapter has to write into. When the wrapped
            // adapter works by reference (nested map case or collection end value) that is a
            // temporary holding the entry stored under the key, otherwise it is the incoming map.
            val targetVarName =
                if (!mapValueResultAdapter.requiresContainsKeyCheck()) {
                    valuesVarName
                } else {
                    val holderVarName = scope.getTmpVar("_values")
                    builder.addLocalVariable(
                        holderVarName,
                        mapValueResultAdapter.getDeclarationTypeName(),
                    )
                    val keyPresentCheck =
                        if (isSparse) "if (%L.get(%L) != null)" else "if (%L.containsKey(%L))"
                    builder
                        .beginControlFlow(keyPresentCheck, valuesVarName, keyVarName)
                        .applyTo { language ->
                            // Key already present: reuse the stored entry.
                            val getFunction =
                                if (language == CodeLanguage.KOTLIN && !isSparse) {
                                    "getValue"
                                } else {
                                    "get"
                                }
                            addStatement(
                                "%L = %L.%L(%L)",
                                holderVarName,
                                valuesVarName,
                                getFunction,
                                keyVarName,
                            )
                        }
                        .nextControlFlow("else")
                        .apply {
                            // Key not seen yet: create the entry and register it under the key.
                            addStatement(
                                "%L = %L",
                                holderVarName,
                                mapValueResultAdapter.getInstantiationCodeBlock(),
                            )
                            addStatement(
                                "%L.put(%L, %L)",
                                valuesVarName,
                                keyVarName,
                                holderVarName,
                            )
                        }
                        .endControlFlow()

                    // The wrapped adapter's null check goes after the entry was registered, so in
                    // a nested mapping the key is still added with an empty entry as its value.
                    mapValueResultAdapter.generateContinueColumnCheck(
                        scope,
                        stmtVarName,
                        dupeColumnsIndexAdapter,
                    )
                    holderVarName
                }

            val keyAbsentCheck =
                if (isSparse) "if (%L.get(%L) == null)" else "if (!%L.containsKey(%L))"
            val putIntoTarget: XCodeBlock.Builder.(String, Boolean) -> Unit =
                { valueVarName, doKeyCheck ->
                    // For consistency purposes, in the one-to-one object mapping case, if
                    // multiple values are encountered for the same key, we will only
                    // consider the first ever encountered mapping.
                    if (doKeyCheck) {
                        beginControlFlow(keyAbsentCheck, targetVarName, keyVarName)
                    }
                    addStatement("%L.put(%L, %L)", targetVarName, keyVarName, valueVarName)
                    if (doKeyCheck) {
                        endControlFlow()
                    }
                }

            mapValueResultAdapter.convert(
                scope = scope,
                valuesVarName = targetVarName,
                stmtVarName = stmtVarName,
                dupeColumnsIndexAdapter = dupeColumnsIndexAdapter,
                addPutValueCode = putIntoTarget,
            )
        }

        /** Adds the `continue` guard built from the columns of this level's key row adapter. */
        override fun generateContinueColumnCheck(
            scope: CodeGenScope,
            stmtVarName: String,
            dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
        ) {
            val keyColumnsGuard =
                getContinueColumnNullCheck(
                    rowAdapter = keyRowAdapter,
                    stmtVarName = stmtVarName,
                    dupeColumnsIndexAdapter = dupeColumnsIndexAdapter,
                )
            scope.builder.add(keyColumnsGuard)
        }
    }

    /**
     * An [EndMapValueResultAdapter] contains only the value information regarding the innermost map
     * of the returned nested map.
     *
     * The [convert] function implementation for an [EndMapValueResultAdapter] uses the value row
     * adapter to resolve the innermost map's value, regardless of whether it is a collection type
     * or not.
     */
    class EndMapValueResultAdapter(
        private val valueRowAdapter: RowAdapter,
        private val valueTypeArg: XType,
        private val valueCollectionType: MultimapQueryResultAdapter.CollectionValueType?,
    ) : MapValueResultAdapter(rowAdapters = listOf(valueRowAdapter)) {

        /** Only collection values need the contains-key check. */
        override fun requiresContainsKeyCheck(): Boolean = valueCollectionType != null

        // The type name used to declare the result map value.
        // For Map<Foo, Bar> it is Bar.
        // For Map<Foo, List<Bar>> it is the collection type's class name parametrized by Bar
        // (presumably the collection interface rather than the concrete class; confirm against
        // MultimapQueryResultAdapter.CollectionValueType).
        override fun getDeclarationTypeName(): XTypeName {
            val elementTypeName = valueTypeArg.asTypeName()
            return valueCollectionType?.className?.parametrizedBy(elementTypeName)
                ?: elementTypeName
        }

        // The expression creating the concrete result map value.
        // For LIST it instantiates ARRAY_LIST in Java and calls MUTABLE_LIST_OF in Kotlin.
        // For SET it instantiates HASH_SET in Java and calls MUTABLE_SET_OF in Kotlin.
        // Otherwise it is a new instance of the value type itself.
        override fun getInstantiationCodeBlock(): XCodeBlock =
            when (valueCollectionType) {
                MultimapQueryResultAdapter.CollectionValueType.LIST ->
                    buildCodeBlock { language ->
                        when (language) {
                            CodeLanguage.KOTLIN ->
                                add("%M()", KotlinCollectionMemberNames.MUTABLE_LIST_OF)
                            CodeLanguage.JAVA ->
                                add(
                                    "new %T()",
                                    ARRAY_LIST.parametrizedBy(valueTypeArg.asTypeName()),
                                )
                        }
                    }
                MultimapQueryResultAdapter.CollectionValueType.SET ->
                    buildCodeBlock { language ->
                        when (language) {
                            CodeLanguage.KOTLIN ->
                                add("%M()", KotlinCollectionMemberNames.MUTABLE_SET_OF)
                            CodeLanguage.JAVA ->
                                add("new %T()", HASH_SET.parametrizedBy(valueTypeArg.asTypeName()))
                        }
                    }
                else -> XCodeBlock.ofNewInstance(valueTypeArg.asTypeName())
            }

        /**
         * Generates the code that reads the innermost value and stores it.
         *
         * Collection values are read and appended to [valuesVarName] with `add`. Single values
         * are stored through [addPutValueCode], after a check on whether all value columns are
         * NULL.
         *
         * @param scope scope providing the code builder and temporary variable names
         * @param valuesVarName name of the variable collection values are added to
         * @param stmtVarName name of the statement variable the row adapter reads from
         * @param dupeColumnsIndexAdapter when non-null, supplies the value's column index vars
         * @param addPutValueCode emits the code storing a single value; receives the value
         *   expression and whether a key check should guard the put
         */
        override fun convert(
            scope: CodeGenScope,
            valuesVarName: String,
            stmtVarName: String,
            dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
            addPutValueCode: XCodeBlock.Builder.(String, Boolean) -> Unit,
        ) {
            val builder = scope.builder
            val tmpValueVarName = scope.getTmpVar("_value")

            // If we have a collection type, then this means that we have a 1-to-many mapping
            // as opposed to a 1-to-1 mapping.
            if (valueCollectionType != null) {
                builder.addLocalVariable(tmpValueVarName, valueTypeArg.asTypeName())
                valueRowAdapter.convert(tmpValueVarName, stmtVarName, scope)
                builder.addStatement("%L.add(%L)", valuesVarName, tmpValueVarName)
                return
            }

            check(valueRowAdapter is QueryMappedRowAdapter)
            val valueIndexVars =
                dupeColumnsIndexAdapter?.getIndexVarsForMapping(valueRowAdapter.mapping)
                    ?: valueRowAdapter.getDefaultIndexAdapter().getIndexVars()
            val allValueColumnsNull =
                getColumnNullCheckCode(stmtVarName = stmtVarName, indexVars = valueIndexVars)

            // Perform value columns null check, in a 1-to-1 mapping we still add the key
            // with a null value entry if permitted.
            builder
                .beginControlFlow("if (%L)", allValueColumnsNull)
                .applyTo { language ->
                    val nullForbidden =
                        language == CodeLanguage.KOTLIN &&
                            valueTypeArg.nullability == XNullability.NONNULL
                    if (nullForbidden) {
                        addStatement(
                            "error(%S)",
                            "The column(s) of the map value object of type " +
                                "'$valueTypeArg' are NULL but the map's value type " +
                                "argument expect it to be NON-NULL",
                        )
                    } else {
                        // Must stay on this lambda's receiver, not on the outer builder.
                        addPutValueCode("null", false)
                        addStatement("continue")
                    }
                }
                .endControlFlow()

            builder.addLocalVariable(tmpValueVarName, valueTypeArg.asTypeName())
            valueRowAdapter.convert(tmpValueVarName, stmtVarName, scope)
            builder.addPutValueCode(tmpValueVarName, true)
        }

        /** Adds the `continue` guard built from the columns of the value row adapter. */
        override fun generateContinueColumnCheck(
            scope: CodeGenScope,
            stmtVarName: String,
            dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
        ) {
            val valueColumnsGuard =
                getContinueColumnNullCheck(
                    rowAdapter = valueRowAdapter,
                    stmtVarName = stmtVarName,
                    dupeColumnsIndexAdapter = dupeColumnsIndexAdapter,
                )
            scope.builder.add(valueColumnsGuard)
        }
    }

    /**
     * Utility function that returns a code block with an `if` statement that executes `continue`
     * when all matched properties of [rowAdapter] are null.
     *
     * The index variables come from [dupeColumnsIndexAdapter] when it is non-null, otherwise from
     * the row adapter's default index adapter. Fails with [IllegalStateException] if [rowAdapter]
     * is not a [QueryMappedRowAdapter].
     */
    protected fun getContinueColumnNullCheck(
        rowAdapter: RowAdapter,
        stmtVarName: String,
        dupeColumnsIndexAdapter: AmbiguousColumnIndexAdapter?,
    ): XCodeBlock {
        val guard = XCodeBlock.builder()
        check(rowAdapter is QueryMappedRowAdapter)
        val indexVars =
            dupeColumnsIndexAdapter?.getIndexVarsForMapping(rowAdapter.mapping)
                ?: rowAdapter.getDefaultIndexAdapter().getIndexVars()
        val allColumnsNull = getColumnNullCheckCode(stmtVarName = stmtVarName, indexVars = indexVars)
        guard.beginControlFlow("if (%L)", allColumnsNull)
        guard.addStatement("continue")
        guard.endControlFlow()
        return guard.build()
    }

    /**
     * Generates a code expression that verifies if all matched properties are null: one
     * `isNull` call on [stmtVarName] per entry of [indexVars], joined with `&&`.
     */
    protected fun getColumnNullCheckCode(stmtVarName: String, indexVars: List<ColumnIndexVar>) =
        buildCodeBlock { language ->
            // The operands are separated by the %W placeholder in Java and a plain space in Kotlin.
            val andSeparator =
                when (language) {
                    CodeLanguage.JAVA -> "%W&&%W"
                    CodeLanguage.KOTLIN -> " && "
                }
            val isNullChecks =
                indexVars.map { columnIndexVar ->
                    XCodeBlock.of("%L.isNull(%L)", stmtVarName, columnIndexVar.indexVar)
                }
            // One %L placeholder per check, filled in with the check blocks below.
            val format = isNullChecks.joinToString(separator = andSeparator) { "%L" }
            add(format, *isNullChecks.toTypedArray())
        }
}
